import numpy as np

from autograd.builtins import NamedTupleVSpace
from autograd.extend import VSpace


class ArrayVSpace(VSpace):
    """Vector space of numpy arrays sharing one shape and dtype.

    The space is built from an example value; only that value's ``shape`` and
    ``dtype`` are retained. In this module it is used for non-complex arrays
    and real floating-point scalars.
    """

    def __init__(self, value):
        value = np.asarray(value)
        self.shape = value.shape
        self.dtype = value.dtype

    @property
    def size(self):
        """Number of scalar components (the product of ``shape``), as an ``int``."""
        # np.prod(()) is the *float* 1.0, so without the cast a 0-d (scalar)
        # space would report a float size. int() keeps every shape integral.
        return int(np.prod(self.shape))

    @property
    def ndim(self):
        """Number of array dimensions."""
        return len(self.shape)

    def zeros(self):
        """Return an all-zeros array with this space's shape and dtype."""
        return np.zeros(self.shape, dtype=self.dtype)

    def ones(self):
        """Return an all-ones array with this space's shape and dtype."""
        return np.ones(self.shape, dtype=self.dtype)

    def standard_basis(self):
        """Yield one one-hot array per element position, in ``np.ndindex`` order."""
        for idxs in np.ndindex(*self.shape):
            vect = np.zeros(self.shape, dtype=self.dtype)
            vect[idxs] = 1
            yield vect

    def randn(self):
        """Return standard-normal samples with this space's shape, cast to its dtype."""
        # np.random.randn() with no arguments returns a Python float; np.array
        # wraps it so that .astype is available for the 0-d case too.
        return np.array(np.random.randn(*self.shape)).astype(self.dtype)

    def _inner_prod(self, x, y):
        """Dot product of the flattened arrays ``x`` and ``y``."""
        return np.dot(np.ravel(x), np.ravel(y))


class ComplexArrayVSpace(ArrayVSpace):
    """Vector space of complex numpy arrays, counted as a real space of twice the size.

    Each complex entry contributes two real components (real and imaginary
    part). This is reflected in ``size``, in ``standard_basis`` (which yields
    both ``1`` and ``1j`` at every position), and in ``_inner_prod``, which
    returns the real part of the conjugated dot product.
    """

    iscomplex = True

    @property
    def size(self):
        """Number of real components: twice the number of complex entries, as an ``int``."""
        # np.prod(()) is the float 1.0; cast so 0-d spaces report 2, not 2.0.
        return int(np.prod(self.shape)) * 2

    def ones(self):
        """Return an array with this space's shape and dtype whose entries are all ``1 + 1j``."""
        return np.full(self.shape, 1 + 1j, dtype=self.dtype)

    def standard_basis(self):
        """Yield, for each position, a basis array holding ``1`` and then one holding ``1j``."""
        for idxs in np.ndindex(*self.shape):
            for v in [1.0, 1.0j]:
                vect = np.zeros(self.shape, dtype=self.dtype)
                vect[idxs] = v
                yield vect

    def randn(self):
        """Return samples whose real and imaginary parts are independent standard normals."""
        # The real part is drawn first and the imaginary part second.
        real = np.array(np.random.randn(*self.shape)).astype(self.dtype)
        imag = np.array(np.random.randn(*self.shape)).astype(self.dtype)
        return real + 1.0j * imag

    def _inner_prod(self, x, y):
        """Real part of ``conj(x) . y`` over the flattened arrays."""
        return np.real(np.dot(np.conj(np.ravel(x)), np.ravel(y)))

    def _covector(self, x):
        """Return the complex conjugate of ``x``."""
        return np.conj(x)


# ndarrays pick their space at construction time: complex-typed arrays get the
# complex space, everything else the plain array space.
VSpace.register(
    np.ndarray,
    lambda arr: (ComplexArrayVSpace if np.iscomplexobj(arr) else ArrayVSpace)(arr),
)

# Python and NumPy real floating-point scalar types share the array space.
for scalar_cls in (float, np.longdouble, np.float64, np.float32, np.float16):
    ArrayVSpace.register(scalar_cls)

# Python and NumPy complex scalar types use the complex array space.
for scalar_cls in (complex, np.clongdouble, np.complex64, np.complex128):
    ComplexArrayVSpace.register(scalar_cls)


# Register a NamedTupleVSpace for each np.linalg named-tuple result type
# (EigResult, EighResult, QRResult, SlogdetResult, SVDResult) on the NumPy
# versions gated below. On NumPy 2.x the types are taken from
# np.linalg._linalg; on NumPy 1.25+ from np.linalg.linalg.
_numpy_version = np.lib.NumpyVersion(np.__version__)
if _numpy_version >= "2.0.0":
    _linalg_results_module = np.linalg._linalg
elif _numpy_version >= "1.25.0":
    # Compare via NumpyVersion, not raw strings: lexicographic comparison
    # mis-orders versions (e.g. "1.9.0" sorts after "1.25"). NumpyVersion
    # requires a full X.Y.Z string, hence "1.25.0".
    _linalg_results_module = np.linalg.linalg
else:
    # Older NumPy: nothing is registered.
    _linalg_results_module = None

if _linalg_results_module is not None:

    class EigResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_results_module.EigResult

    class EighResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_results_module.EighResult

    class QRResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_results_module.QRResult

    class SlogdetResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_results_module.SlogdetResult

    class SVDResultVSpace(NamedTupleVSpace):
        seq_type = _linalg_results_module.SVDResult

    EigResultVSpace.register(_linalg_results_module.EigResult)
    EighResultVSpace.register(_linalg_results_module.EighResult)
    QRResultVSpace.register(_linalg_results_module.QRResult)
    SlogdetResultVSpace.register(_linalg_results_module.SlogdetResult)
    SVDResultVSpace.register(_linalg_results_module.SVDResult)
